diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index 3efda0021d..323173b90e 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -333,7 +333,7 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 8) {
-                return 4 * (threadIdx.x / 16) + l;
+                return ne * (threadIdx.x / 16) + l;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -359,7 +359,7 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 8) {
-                return 2 * (threadIdx.x / 16) + l;
+                return ne * (threadIdx.x / 16) + l;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index 1a203f2c76..43fc4f6b4b 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -808,8 +808,8 @@ static void mul_mat_f_switch_cols_per_block(
     }
 }
 
-#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
-    template void mul_mat_f_cuda<T, ncols_dst>( \
+#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
+    template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
     const T * x, const float * y, const int32_t * ids, float * dst, \
     const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
     const int64_t stride_col_id, const int64_t stride_row_id, \
@@ -819,32 +819,32 @@ static void mul_mat_f_switch_cols_per_block(
         cudaStream_t stream, const mmf_ids_data * ids_data);
 
 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-#define DECL_MMF_CASE_EXTERN(ncols_dst) \
-    extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
-    extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
-    extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+#define DECL_MMF_CASE_EXTERN(nrows_dst, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(float, nrows_dst, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(half2, nrows_dst, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(nv_bfloat162, nrows_dst, ncols_dst)
 
-#define DECL_MMF_CASE(ncols_dst) \
-    DECL_MMF_CASE_HELPER(float, ncols_dst) \
-    DECL_MMF_CASE_HELPER(half2, ncols_dst) \
-    DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+#define DECL_MMF_CASE(nrows_dst, ncols_dst) \
+    DECL_MMF_CASE_HELPER(float, nrows_dst, ncols_dst) \
+    DECL_MMF_CASE_HELPER(half2, nrows_dst, ncols_dst) \
+    DECL_MMF_CASE_HELPER(nv_bfloat162, nrows_dst, ncols_dst)
 
-DECL_MMF_CASE_EXTERN(1);
-DECL_MMF_CASE_EXTERN(2);
-DECL_MMF_CASE_EXTERN(3);
-DECL_MMF_CASE_EXTERN(4);
-DECL_MMF_CASE_EXTERN(5);
-DECL_MMF_CASE_EXTERN(6);
-DECL_MMF_CASE_EXTERN(7);
-DECL_MMF_CASE_EXTERN(8);
-DECL_MMF_CASE_EXTERN(9);
-DECL_MMF_CASE_EXTERN(10);
-DECL_MMF_CASE_EXTERN(11);
-DECL_MMF_CASE_EXTERN(12);
-DECL_MMF_CASE_EXTERN(13);
-DECL_MMF_CASE_EXTERN(14);
-DECL_MMF_CASE_EXTERN(15);
-DECL_MMF_CASE_EXTERN(16);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 1);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 2);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 3);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 4);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 5);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 6);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 7);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 8);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 9);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 10);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 11);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 12);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 13);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 14);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 15);
+DECL_MMF_CASE_EXTERN(MMF_ROWS_PER_BLOCK, 16);
 #else
-#define DECL_MMF_CASE(ncols_dst)
+#define DECL_MMF_CASE(MMF_ROWS_PER_BLOCK, ncols_dst)
 #endif