HIP: add mmf for CDNA (#18896)
* refactor mmf rows_per_block
* speed up compile
* pass cdna compile
* fix cuda error
* clean up mmf
* f32 mmf
* clean float mma
* fix mmf error
* faster mmf
* extend tile k
* fix compile error
* Revert "extend tile k"
This reverts commit 4d2ef3d483.
* fix smem overflow
* speed up compiling mmf
* speed up compile for hip
* 512 block for cdna
* config pad size
* fix as comment
* update select logic
* move some code to cuh
* fix as comment
* correct cdna3 config
---------
Co-authored-by: zhang hui <you@example.com>
Commit f3dd7b8e68 (parent eed25bc6b0)
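The net effect on the host side is that the mmf path now sizes itself per architecture: CDNA (MFMA) devices run 64 rows per block with up to 512 threads per block, while everything else keeps the previous 32 rows and 256 threads. A condensed sketch of that selection, using the helper names introduced in the diff below (the if/else bodies from the diff are collapsed into ternaries here; GGML_CUDA_CC_IS_CDNA is the existing architecture check):

#define MMF_ROWS_PER_BLOCK      32
#define MMF_ROWS_PER_BLOCK_CDNA 64

// 64-wide CDNA wavefronts drive a 64-row tile; other backends keep 32 rows.
static __forceinline__ int mmf_get_rows_per_block(const int cc) {
    return GGML_CUDA_CC_IS_CDNA(cc) ? MMF_ROWS_PER_BLOCK_CDNA : MMF_ROWS_PER_BLOCK;
}

// CDNA also gets a larger thread-block budget when choosing nwarps.
static __forceinline__ int64_t mmf_get_max_block_size(const int cc) {
    return GGML_CUDA_CC_IS_CDNA(cc) ? 512 : 256;
}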
@@ -333,7 +333,33 @@ namespace ggml_cuda_mma {
        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 16 && J == 8) {
                return 4 * (threadIdx.x / 16) + l;
                return ne * (threadIdx.x / 16) + l;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#elif defined(AMD_MFMA_AVAILABLE)
        static constexpr int ne = I * J / 64;
        half2 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            if (I == 16 && J == 8) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 16 && J == 8) {
                return threadIdx.x % 16;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 16 && J == 8) {
                return ne * (threadIdx.x / 16) + l;
            } else {
                NO_DEVICE_CODE;
                return -1;
@@ -391,7 +417,22 @@ namespace ggml_cuda_mma {
        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;

#if defined(AMD_WMMA_AVAILABLE)
        static constexpr int ne = I * J / 32;
        static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
        }

        static __device__ __forceinline__ int get_i(const int l) {
            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
        }

        static __device__ __forceinline__ int get_j(const int l) {
            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
        }
#elif defined(AMD_MFMA_AVAILABLE)
        static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
@@ -945,6 +986,32 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
    }

    template <data_layout dl_ab, data_layout dl_d>
    static __device__ __forceinline__ void mma(
            tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
#ifdef AMD_MFMA_AVAILABLE
        using floatx4_t = __attribute__((ext_vector_type(4))) float;
        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3)
        using floatx2_t = __attribute__((ext_vector_type(2))) float;
        const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
        const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA2) || defined(CDNA1)
#pragma unroll
        for (int i = 0; i < 2; ++i) {
            acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
        }
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // defined(CDNA3)
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
                                                            const tile<16, 8, int> & A,
                                                            const tile<8, 8, int> & B,
@@ -1054,6 +1121,13 @@ namespace ggml_cuda_mma {
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // RDNA4
#elif defined(AMD_MFMA_AVAILABLE)
        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
        using floatx4_t = __attribute__((ext_vector_type(4))) float;
        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
        const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
        const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
@@ -1081,11 +1155,31 @@ namespace ggml_cuda_mma {
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // RDNA4
#endif // defined(RDNA4)
#elif defined(AMD_MFMA_AVAILABLE)
        using floatx4_t = __attribute__((ext_vector_type(4))) float;
        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3) || defined(CDNA2)
        using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
        const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
        const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA1)
#pragma unroll
        for (int i = 0; i < 2; ++i) {
            using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
            const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
            const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
            acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
        }
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
#endif // defined(CDNA3) || defined(CDNA2)
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // defined(AMD_WMMA_AVAILABLE)
    }

    template <data_layout dl_d, data_layout dl_ab>
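The MFMA tiles added above follow the 64-lane CDNA wavefront: a 16x8 A/B tile keeps ne = 16*8/64 = 2 values per lane, and the 16x16 float accumulator keeps 16*16/64 = 4 values per lane, which is exactly the floatx4_t the builtins consume. A rough per-lane picture of the index mapping implemented by get_i()/get_j() (lane = threadIdx.x within the wavefront):

// 16x8 MFMA tile, ne = 2:
//   i = lane % 16,  j = ne * (lane / 16) + l
//   lane  0 -> (i = 0,  j = 0), (i = 0,  j = 1)
//   lane 17 -> (i = 1,  j = 2), (i = 1,  j = 3)
//   lane 63 -> (i = 15, j = 6), (i = 15, j = 7)
// For F32 the mma() wrapper above picks the instruction per generation: CDNA3 uses
// the 16x16x8 XF32 builtin on float2 fragments, while CDNA1/CDNA2 issue the
// 16x16x4 F32 builtin twice, one scalar A/B element per step.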
@@ -2,6 +2,13 @@
#include "mmf.cuh"
#include "mmid.cuh"

static __forceinline__ int mmf_get_rows_per_block(const int cc) {
    if (GGML_CUDA_CC_IS_CDNA(cc)) {
        return MMF_ROWS_PER_BLOCK_CDNA;
    } else {
        return MMF_ROWS_PER_BLOCK;
    }
}

void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
    GGML_ASSERT( src1->type == GGML_TYPE_F32);
@@ -89,28 +96,32 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
        ids_info_ptr = &ids_info;
    }

    const int device = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[device].cc;
    const int rows_per_block = mmf_get_rows_per_block(cc);

    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            constexpr int vals_per_T = 1;
            mul_mat_f_switch_cols_per_block(
                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
            mul_mat_f_switch_rows_per_block<float>(
                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        case GGML_TYPE_F16: {
            const half2 * src0_d = (const half2 *) src0->data;
            constexpr int vals_per_T = 2;
            mul_mat_f_switch_cols_per_block(
                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
            mul_mat_f_switch_rows_per_block<half2>(
                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
            constexpr int vals_per_T = 2;
            mul_mat_f_switch_cols_per_block(
                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
            mul_mat_f_switch_rows_per_block<nv_bfloat162>(
                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
@@ -140,7 +151,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
            return false;
        }
    }
    if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
    if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) {
        return false;
    }

    if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) {
        return false;
    }

@@ -153,6 +168,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
    } else {
        if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
            return false;
        } else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
            // TODO: treat CDNA2 like CDNA1 for now; tune the performance once CDNA2 hardware is available.
            return false;
        } else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
            return false;
        } else if (src1_ncols > 16) {
            return false;
        }
@@ -160,11 +180,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const

    switch (type) {
        case GGML_TYPE_F32:
            return ampere_mma_available(cc);
            return ampere_mma_available(cc) || amd_mfma_available(cc);
        case GGML_TYPE_F16:
            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
        case GGML_TYPE_BF16:
            return ampere_mma_available(cc) || amd_wmma_available(cc);
            return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
        default:
            return false;
    }
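Summarizing the gating above: the MFMA path is now advertised for F32, F16 and BF16, subject to a few CDNA-specific guards (a recap of the checks in ggml_cuda_should_use_mmf, not additional logic):

// - src0 rows must tile evenly: src0_ne[1] % mmf_get_rows_per_block(cc) == 0,
//   i.e. a multiple of 64 on CDNA and of 32 elsewhere.
// - BF16 mmf is not used on CDNA3.
// - F16/BF16 mmf is skipped on CDNA1 and CDNA2 (treated alike for now), where
//   other kernels are currently the faster choice.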
@@ -7,6 +7,31 @@
using namespace ggml_cuda_mma;

#define MMF_ROWS_PER_BLOCK 32
#define MMF_ROWS_PER_BLOCK_CDNA 64

static __forceinline__ int64_t mmf_get_max_block_size(int cc) {
    if (GGML_CUDA_CC_IS_CDNA(cc)) {
        return 512;
    } else {
        return 256;
    }
}

static __forceinline__ int mmf_get_padding(int cc) {
    if (GGML_CUDA_CC_IS_CDNA(cc)) {
        return 2;
    } else {
        return 4;
    }
}

static constexpr __device__ int mmf_get_padding() {
#if defined(AMD_MFMA_AVAILABLE)
    return 2;
#else
    return 4;
#endif // defined(AMD_MFMA_AVAILABLE)
}

struct mmf_ids_data {
    const int32_t * ids_src_compact = nullptr;
@@ -29,23 +54,25 @@ static __global__ void mul_mat_f(
        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
    // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
    constexpr bool is_tf32 = std::is_same_v<T, float>;
    constexpr int tile_B_I = is_tf32 ? 8 : 16;
    constexpr int tile_C_J = is_tf32 ? 8 : 16;
    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
    typedef tile<16, 8, T, ab_layout> tile_A;
    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8, T, get_input_data_layout()> tile_A;
    typedef tile<16, 8, T, get_input_data_layout()> tile_B;
    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#elif defined(AMD_MFMA_AVAILABLE)
    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
    typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
    typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8, T> tile_A;
    typedef tile<8, 8, T> tile_B;
    typedef tile<16, 8, float> tile_C;
@@ -57,7 +84,7 @@ static __global__ void mul_mat_f(
    }

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    constexpr int tile_k_padded = warp_size + 4;
    constexpr int tile_k_padded = warp_size + mmf_get_padding();
    constexpr int ntA = rows_per_block / tile_A::I;
    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;

@@ -198,7 +225,7 @@ static __global__ void mul_mat_f(
    }

    float * buf_iw = (float *) compute_base;
    constexpr int kiw = nwarps*rows_per_block + 4;
    constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();

    if (nwarps > 1) {
        __syncthreads();
@@ -228,27 +255,34 @@ static __global__ void mul_mat_f(
            return;
        }

        float sum = 0.0f;
        static_assert(rows_per_block == warp_size, "need loop/check");
        float sum[rows_per_block/warp_size] = {0.0f};
        static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll
        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
            const int i = i0 + threadIdx.x;
#pragma unroll
            for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
                const int i = i0 + i1*warp_size + threadIdx.x;

                sum += buf_iw[j*kiw + i];
                sum[i1] += buf_iw[j*kiw + i];
            }
        }

        if constexpr (!has_ids) {
            dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
#pragma unroll
            for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
                dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
            }
        } else {
            const int slot = (j < cols_per_block) ? slot_map[j] : -1;
            if (slot >= 0 && (col_base + j) < ncols_dst_total) {
                dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
#pragma unroll
                for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
                    dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
                }
            }
        }
    }
#ifdef VOLTA_MMA_AVAILABLE
    }
#endif //VOLTA_MMA_AVAILABLE
#else
    GGML_UNUSED_VARS(x, y, ids, dst,
        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
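The scalar partial sum becomes a small array so that a single lane can own more than one output row when rows_per_block exceeds the warp size. With the current configurations the array still has one element (32 rows on 32-wide warps for NVIDIA and RDNA, 64 rows on 64-wide CDNA wavefronts), but the indexing is already general:

//   float sum[rows_per_block / warp_size];              // 1 element on today's targets
//   element i0 stores row:  row0 + i0*warp_size + threadIdx.x
// so a configuration where rows_per_block is a larger multiple of warp_size would
// still reduce and write every row correctly.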
@@ -256,7 +290,7 @@ static __global__ void mul_mat_f(
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    NO_DEVICE_CODE;
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}

// This kernel is for larger batch sizes of mul_mat_id
@@ -271,23 +305,25 @@ static __global__ void mul_mat_f_ids(
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const uint3 sis1_fd, const uint3 nch_fd) {
    // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
    constexpr bool is_tf32 = std::is_same_v<T, float>;
    constexpr int tile_B_I = is_tf32 ? 8 : 16;
    constexpr int tile_C_J = is_tf32 ? 8 : 16;
    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
    typedef tile<16, 8, T, ab_layout> tile_A;
    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8, T, get_input_data_layout()> tile_A;
    typedef tile<16, 8, T, get_input_data_layout()> tile_B;
    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#elif defined(AMD_MFMA_AVAILABLE)
    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
    typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
    typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
    typedef tile<16, 8, T> tile_A;
    typedef tile<8, 8, T> tile_B;
    typedef tile<16, 8, float> tile_C;
@@ -300,7 +336,7 @@ static __global__ void mul_mat_f_ids(

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    constexpr int tile_k_padded = warp_size + 4;
    constexpr int tile_k_padded = warp_size + mmf_get_padding();
    constexpr int ntA = rows_per_block / tile_A::I;
    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;

@@ -467,7 +503,7 @@ static __global__ void mul_mat_f_ids(
    }

    float * buf_iw = (float *) compute_base;
    constexpr int kiw = nwarps*rows_per_block + 4;
    constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();

    if (nwarps > 1) {
        __syncthreads();
@@ -497,13 +533,16 @@ static __global__ void mul_mat_f_ids(
            return;
        }

        float sum = 0.0f;
        static_assert(rows_per_block == warp_size, "need loop/check");
        float sum[rows_per_block/warp_size] = {0.0f};
        static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll
        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
            const int i = i0 + threadIdx.x;
#pragma unroll
            for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
                const int i = i0 + i1*warp_size + threadIdx.x;

                sum += buf_iw[j*kiw + i];
                sum[i1] += buf_iw[j * kiw + i];
            }
        }

        const int global_j = col_base + j;
@@ -513,23 +552,24 @@ static __global__ void mul_mat_f_ids(
            const int token = (int) qrm.x;
            if (token < ncols_dst_total) {
                const int slot = (int) qrm.y;
                dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
#pragma unroll
                for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
                    dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
                }
            }
        }
    }
#ifdef VOLTA_MMA_AVAILABLE
    }
#endif // VOLTA_MMA_AVAILABLE
#else
    GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
    NO_DEVICE_CODE;
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}

template<typename T, int cols_per_block, int nwarps>
template<typename T, int rows_per_block, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
    const T * x, const float * y, const int32_t * ids, float * dst,
    const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
@@ -553,7 +593,7 @@ static inline void mul_mat_f_switch_ids(
        const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
        const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);

        mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
        mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
             ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
@@ -564,19 +604,19 @@ static inline void mul_mat_f_switch_ids(
        dim3 block_nums_ids = block_nums;
        block_nums_ids.y *= col_tiles;

        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
        mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    } else {
        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
        mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    }
}

template <typename T, int cols_per_block>
template <typename T, int rows_per_block, int cols_per_block>
void mul_mat_f_cuda(
    const T * x, const float * y, const int32_t * ids, float * dst,
    const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@@ -605,7 +645,7 @@ void mul_mat_f_cuda(

    int64_t nwarps_best = 1;
    int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
    int64_t max_block_size = 256;
    int64_t max_block_size = mmf_get_max_block_size(cc);
    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
        if (niter < niter_best) {
@@ -614,10 +654,9 @@ void mul_mat_f_cuda(
        }
    }

    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
    const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;
    const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;
    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;
    const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
    const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
    const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
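With the formulas just above, the CDNA configuration stays within the 64 KiB of LDS available per workgroup. A worked example for the largest case (assumed values: warp_size = 64, mmf_get_padding(cc) = 2, rows_per_block = 64, nwarps_best = 8 from the 512-thread cap, cols_per_block = 16, and tile_A_16::I = 16 on the MFMA path):

//   nbytes_shared_iter    = 8 * 16 * (64 + 2) * 4             = 33792 bytes
//   nbytes_shared_combine = GGML_PAD(16, 16) * (8*64 + 2) * 4 = 32896 bytes
//   nbytes_shared         = max(33792, 32896)                 = 33792 bytes
//   plus a 64-byte slot map when ids are used.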
@@ -628,56 +667,56 @@ void mul_mat_f_cuda(

    switch (nwarps_best) {
        case 1: {
            mul_mat_f_switch_ids<T, cols_per_block, 1>(
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 2: {
            mul_mat_f_switch_ids<T, cols_per_block, 2>(
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 3: {
            mul_mat_f_switch_ids<T, cols_per_block, 3>(
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 4: {
            mul_mat_f_switch_ids<T, cols_per_block, 4>(
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 5: {
            mul_mat_f_switch_ids<T, cols_per_block, 5>(
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 6: {
            mul_mat_f_switch_ids<T, cols_per_block, 6>(
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 7: {
            mul_mat_f_switch_ids<T, cols_per_block, 7>(
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 8: {
            mul_mat_f_switch_ids<T, cols_per_block, 8>(
            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
@@ -691,7 +730,7 @@ void mul_mat_f_cuda(
    GGML_UNUSED_VARS(nchannels_y);
}

template <typename T>
template <typename T, int rows_per_block>
static void mul_mat_f_switch_cols_per_block(
    const T * x, const float * y, const int32_t * ids, float * dst,
    const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@@ -708,82 +747,82 @@ static void mul_mat_f_switch_cols_per_block(

    switch (ncols_case) {
        case 1: {
            mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 2: {
            mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 3: {
            mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 4: {
            mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 5: {
            mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 6: {
            mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 7: {
            mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 8: {
            mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 9: {
            mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 10: {
            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 11: {
            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 12: {
            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 13: {
            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 14: {
            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 15: {
            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 16: {
            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
            mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
@@ -793,8 +832,36 @@ static void mul_mat_f_switch_cols_per_block(
    }
}

#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
    template void mul_mat_f_cuda<T, ncols_dst>( \
template <typename T>
static void mul_mat_f_switch_rows_per_block(
    const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,
    const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
    const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
    const int64_t stride_col_id, const int stride_row_id,
    const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
    const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
    const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
    cudaStream_t stream, const mmf_ids_data * ids_data) {
    switch (rows_per_block) {
        case MMF_ROWS_PER_BLOCK: {
            mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(
                x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case MMF_ROWS_PER_BLOCK_CDNA: {
            mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(
                x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        default:
            GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
    }
}

#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
    template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
        const T * x, const float * y, const int32_t * ids, float * dst, \
        const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
        const int64_t stride_col_id, const int64_t stride_row_id, \
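Because rows_per_block is now a template parameter all the way down, the host dispatch gains one extra switch level. The call chain after this change (names as in the diff):

// ggml_cuda_mul_mat_f (mmf.cu)
//   -> mul_mat_f_switch_rows_per_block<T>(rows_per_block, ...)      // runtime: 32 or 64
//      -> mul_mat_f_switch_cols_per_block<T, rows_per_block>(...)   // runtime: 1..16 columns
//         -> mul_mat_f_cuda<T, rows_per_block, cols_per_block>(...)
//            -> mul_mat_f / mul_mat_f_ids <T, rows_per_block, cols_per_block, nwarps, ...>
// Accordingly, DECL_MMF_CASE / DECL_MMF_CASE_EXTERN below instantiate every ncols
// case for both the 32-row and the 64-row variant, which is why the helper macro
// gained its nrows_dst parameter.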
@@ -803,16 +870,22 @@ static void mul_mat_f_switch_cols_per_block(
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
        cudaStream_t stream, const mmf_ids_data * ids_data);

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#if !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
    extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
    extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)

#define DECL_MMF_CASE(ncols_dst) \
    DECL_MMF_CASE_HELPER(float, ncols_dst) \
    DECL_MMF_CASE_HELPER(half2, ncols_dst) \
    DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
    DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
    DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
    DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
    DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
    DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
    DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)

DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);
@@ -62,6 +62,8 @@ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmf*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})

if (GGML_CUDA_FA_ALL_QUANTS)
    file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")