clean up mmf

This commit is contained in:
zhang hui 2026-01-17 11:48:54 +08:00
parent 8c875b23cb
commit 7b43cbc083
2 changed files with 29 additions and 29 deletions

View File

@ -333,7 +333,7 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l;
return ne * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
@ -359,7 +359,7 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x / 16) + l;
return ne * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;

View File

@ -808,8 +808,8 @@ static void mul_mat_f_switch_cols_per_block(
}
}
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
template void mul_mat_f_cuda<T, MMF_ROWS_PER_BLOCK, ncols_dst>( \
#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
const T * x, const float * y, const int32_t * ids, float * dst, \
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
const int64_t stride_col_id, const int64_t stride_row_id, \
@ -819,32 +819,32 @@ static void mul_mat_f_switch_cols_per_block(
cudaStream_t stream, const mmf_ids_data * ids_data);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
#define DECL_MMF_CASE_EXTERN(nrows_dst, ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, nrows_dst, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, nrows_dst, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, nrows_dst, ncols_dst)
#define DECL_MMF_CASE(ncols_dst) \
DECL_MMF_CASE_HELPER(float, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
#define DECL_MMF_CASE(nrows_dst, ncols_dst) \
DECL_MMF_CASE_HELPER(float, nrows_dst, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, nrows_dst, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, nrows_dst, ncols_dst)
DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);
DECL_MMF_CASE_EXTERN(3);
DECL_MMF_CASE_EXTERN(4);
DECL_MMF_CASE_EXTERN(5);
DECL_MMF_CASE_EXTERN(6);
DECL_MMF_CASE_EXTERN(7);
DECL_MMF_CASE_EXTERN(8);
DECL_MMF_CASE_EXTERN(9);
DECL_MMF_CASE_EXTERN(10);
DECL_MMF_CASE_EXTERN(11);
DECL_MMF_CASE_EXTERN(12);
DECL_MMF_CASE_EXTERN(13);
DECL_MMF_CASE_EXTERN(14);
DECL_MMF_CASE_EXTERN(15);
DECL_MMF_CASE_EXTERN(16);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 1);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 2);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 3);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 4);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 5);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 6);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 7);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 8);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 9);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 10);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 11);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 12);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 13);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 14);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 15);
DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 16);
#else
#define DECL_MMF_CASE(ncols_dst)
// Fallback taken when GGML_USE_HIP or GGML_USE_MUSA is defined: DECL_MMF_CASE expands to
// nothing, so no explicit mul_mat_f_cuda instantiations are emitted through it.
// The parameters are named like the non-empty definition (nrows_dst, ncols_dst) so that
// the argument callers pass (e.g. MMF_ROWS_PER_BLOCK) is not reused as a parameter name.
#define DECL_MMF_CASE(nrows_dst, ncols_dst)
#endif