clean up mmf
This commit is contained in:
parent
8c875b23cb
commit
7b43cbc083
|
|
@ -333,7 +333,7 @@ namespace ggml_cuda_mma {
|
|||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return 4 * (threadIdx.x / 16) + l;
|
||||
return ne * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
|
|
@ -359,7 +359,7 @@ namespace ggml_cuda_mma {
|
|||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return 2 * (threadIdx.x / 16) + l;
|
||||
return ne * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
|
|
|
|||
|
|
@ -808,8 +808,8 @@ static void mul_mat_f_switch_cols_per_block(
|
|||
}
|
||||
}
|
||||
|
||||
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
|
||||
template void mul_mat_f_cuda<T, MMF_ROWS_PER_BLOCK, ncols_dst>( \
|
||||
#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
|
||||
template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
|
||||
const T * x, const float * y, const int32_t * ids, float * dst, \
|
||||
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
|
||||
const int64_t stride_col_id, const int64_t stride_row_id, \
|
||||
|
|
@ -819,32 +819,32 @@ static void mul_mat_f_switch_cols_per_block(
|
|||
cudaStream_t stream, const mmf_ids_data * ids_data);
|
||||
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
|
||||
#define DECL_MMF_CASE_EXTERN(nrows_dst, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(float, nrows_dst, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(half2, nrows_dst, ncols_dst) \
|
||||
extern DECL_MMF_CASE_HELPER(nv_bfloat162, nrows_dst, ncols_dst)
|
||||
|
||||
#define DECL_MMF_CASE(ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(float, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(half2, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
|
||||
#define DECL_MMF_CASE(nrows_dst, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(float, nrows_dst, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(half2, nrows_dst, ncols_dst) \
|
||||
DECL_MMF_CASE_HELPER(nv_bfloat162, nrows_dst, ncols_dst)
|
||||
|
||||
DECL_MMF_CASE_EXTERN(1);
|
||||
DECL_MMF_CASE_EXTERN(2);
|
||||
DECL_MMF_CASE_EXTERN(3);
|
||||
DECL_MMF_CASE_EXTERN(4);
|
||||
DECL_MMF_CASE_EXTERN(5);
|
||||
DECL_MMF_CASE_EXTERN(6);
|
||||
DECL_MMF_CASE_EXTERN(7);
|
||||
DECL_MMF_CASE_EXTERN(8);
|
||||
DECL_MMF_CASE_EXTERN(9);
|
||||
DECL_MMF_CASE_EXTERN(10);
|
||||
DECL_MMF_CASE_EXTERN(11);
|
||||
DECL_MMF_CASE_EXTERN(12);
|
||||
DECL_MMF_CASE_EXTERN(13);
|
||||
DECL_MMF_CASE_EXTERN(14);
|
||||
DECL_MMF_CASE_EXTERN(15);
|
||||
DECL_MMF_CASE_EXTERN(16);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 1);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 2);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 3);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 4);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 5);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 6);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 7);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 8);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 9);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 10);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 11);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 12);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 13);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 14);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 15);
|
||||
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 16);
|
||||
#else
|
||||
#define DECL_MMF_CASE(ncols_dst)
|
||||
#define DECL_MMF_CASE(MMF_ROWS_PER_BLOCK, ncols_dst)
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in New Issue