clean up mmf
This commit is contained in:
parent
8c875b23cb
commit
7b43cbc083
|
|
@ -333,7 +333,7 @@ namespace ggml_cuda_mma {
|
||||||
|
|
||||||
static __device__ __forceinline__ int get_j(const int l) {
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
if constexpr (I == 16 && J == 8) {
|
if constexpr (I == 16 && J == 8) {
|
||||||
return 4 * (threadIdx.x / 16) + l;
|
return ne * (threadIdx.x / 16) + l;
|
||||||
} else {
|
} else {
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
return -1;
|
return -1;
|
||||||
|
|
@ -359,7 +359,7 @@ namespace ggml_cuda_mma {
|
||||||
|
|
||||||
static __device__ __forceinline__ int get_j(const int l) {
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
if constexpr (I == 16 && J == 8) {
|
if constexpr (I == 16 && J == 8) {
|
||||||
return 2 * (threadIdx.x / 16) + l;
|
return ne * (threadIdx.x / 16) + l;
|
||||||
} else {
|
} else {
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
return -1;
|
return -1;
|
||||||
|
|
|
||||||
|
|
@ -808,8 +808,8 @@ static void mul_mat_f_switch_cols_per_block(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
|
#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
|
||||||
template void mul_mat_f_cuda<T, MMF_ROWS_PER_BLOCK, ncols_dst>( \
|
template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
|
||||||
const T * x, const float * y, const int32_t * ids, float * dst, \
|
const T * x, const float * y, const int32_t * ids, float * dst, \
|
||||||
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
|
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
|
||||||
const int64_t stride_col_id, const int64_t stride_row_id, \
|
const int64_t stride_col_id, const int64_t stride_row_id, \
|
||||||
|
|
@ -819,32 +819,32 @@ static void mul_mat_f_switch_cols_per_block(
|
||||||
cudaStream_t stream, const mmf_ids_data * ids_data);
|
cudaStream_t stream, const mmf_ids_data * ids_data);
|
||||||
|
|
||||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||||
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
|
#define DECL_MMF_CASE_EXTERN(nrows_dst, ncols_dst) \
|
||||||
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
|
extern DECL_MMF_CASE_HELPER(float, nrows_dst, ncols_dst) \
|
||||||
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
|
extern DECL_MMF_CASE_HELPER(half2, nrows_dst, ncols_dst) \
|
||||||
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
|
extern DECL_MMF_CASE_HELPER(nv_bfloat162, nrows_dst, ncols_dst)
|
||||||
|
|
||||||
#define DECL_MMF_CASE(ncols_dst) \
|
#define DECL_MMF_CASE(nrows_dst, ncols_dst) \
|
||||||
DECL_MMF_CASE_HELPER(float, ncols_dst) \
|
DECL_MMF_CASE_HELPER(float, nrows_dst, ncols_dst) \
|
||||||
DECL_MMF_CASE_HELPER(half2, ncols_dst) \
|
DECL_MMF_CASE_HELPER(half2, nrows_dst, ncols_dst) \
|
||||||
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
|
DECL_MMF_CASE_HELPER(nv_bfloat162, nrows_dst, ncols_dst)
|
||||||
|
|
||||||
DECL_MMF_CASE_EXTERN(1);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 1);
|
||||||
DECL_MMF_CASE_EXTERN(2);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 2);
|
||||||
DECL_MMF_CASE_EXTERN(3);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 3);
|
||||||
DECL_MMF_CASE_EXTERN(4);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 4);
|
||||||
DECL_MMF_CASE_EXTERN(5);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 5);
|
||||||
DECL_MMF_CASE_EXTERN(6);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 6);
|
||||||
DECL_MMF_CASE_EXTERN(7);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 7);
|
||||||
DECL_MMF_CASE_EXTERN(8);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 8);
|
||||||
DECL_MMF_CASE_EXTERN(9);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 9);
|
||||||
DECL_MMF_CASE_EXTERN(10);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 10);
|
||||||
DECL_MMF_CASE_EXTERN(11);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 11);
|
||||||
DECL_MMF_CASE_EXTERN(12);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 12);
|
||||||
DECL_MMF_CASE_EXTERN(13);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 13);
|
||||||
DECL_MMF_CASE_EXTERN(14);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 14);
|
||||||
DECL_MMF_CASE_EXTERN(15);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 15);
|
||||||
DECL_MMF_CASE_EXTERN(16);
|
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 16);
|
||||||
#else
|
#else
|
||||||
#define DECL_MMF_CASE(ncols_dst)
|
#define DECL_MMF_CASE(MMF_ROWS_PER_BLOCK, ncols_dst)
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue