clean up mmf

This commit is contained in:
zhang hui 2026-01-17 11:48:54 +08:00
parent 8c875b23cb
commit 7b43cbc083
2 changed files with 29 additions and 29 deletions

View File

@ -333,7 +333,7 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) { static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) { if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l; return ne * (threadIdx.x / 16) + l;
} else { } else {
NO_DEVICE_CODE; NO_DEVICE_CODE;
return -1; return -1;
@ -359,7 +359,7 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) { static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) { if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x / 16) + l; return ne * (threadIdx.x / 16) + l;
} else { } else {
NO_DEVICE_CODE; NO_DEVICE_CODE;
return -1; return -1;

View File

@ -808,8 +808,8 @@ static void mul_mat_f_switch_cols_per_block(
} }
} }
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \ #define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
template void mul_mat_f_cuda<T, MMF_ROWS_PER_BLOCK, ncols_dst>( \ template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
const T * x, const float * y, const int32_t * ids, float * dst, \ const T * x, const float * y, const int32_t * ids, float * dst, \
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \ const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
const int64_t stride_col_id, const int64_t stride_row_id, \ const int64_t stride_col_id, const int64_t stride_row_id, \
@ -819,32 +819,32 @@ static void mul_mat_f_switch_cols_per_block(
cudaStream_t stream, const mmf_ids_data * ids_data); cudaStream_t stream, const mmf_ids_data * ids_data);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \ #define DECL_MMF_CASE_EXTERN(nrows_dst, ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \ extern DECL_MMF_CASE_HELPER(float, nrows_dst, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \ extern DECL_MMF_CASE_HELPER(half2, nrows_dst, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) extern DECL_MMF_CASE_HELPER(nv_bfloat162, nrows_dst, ncols_dst)
#define DECL_MMF_CASE(ncols_dst) \ #define DECL_MMF_CASE(nrows_dst, ncols_dst) \
DECL_MMF_CASE_HELPER(float, ncols_dst) \ DECL_MMF_CASE_HELPER(float, nrows_dst, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, ncols_dst) \ DECL_MMF_CASE_HELPER(half2, nrows_dst, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) DECL_MMF_CASE_HELPER(nv_bfloat162, nrows_dst, ncols_dst)
DECL_MMF_CASE_EXTERN(1); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 1);
DECL_MMF_CASE_EXTERN(2); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 2);
DECL_MMF_CASE_EXTERN(3); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 3);
DECL_MMF_CASE_EXTERN(4); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 4);
DECL_MMF_CASE_EXTERN(5); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 5);
DECL_MMF_CASE_EXTERN(6); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 6);
DECL_MMF_CASE_EXTERN(7); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 7);
DECL_MMF_CASE_EXTERN(8); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 8);
DECL_MMF_CASE_EXTERN(9); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 9);
DECL_MMF_CASE_EXTERN(10); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 10);
DECL_MMF_CASE_EXTERN(11); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 11);
DECL_MMF_CASE_EXTERN(12); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 12);
DECL_MMF_CASE_EXTERN(13); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 13);
DECL_MMF_CASE_EXTERN(14); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 14);
DECL_MMF_CASE_EXTERN(15); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 15);
DECL_MMF_CASE_EXTERN(16); DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 16);
#else #else
#define DECL_MMF_CASE(ncols_dst) #define DECL_MMF_CASE(MMF_ROWS_PER_BLOCK, ncols_dst)
#endif #endif