enable mmf for rdna4

This commit is contained in:
zhang hui 2025-11-22 14:56:20 +08:00
parent 028f93ef98
commit 698c9f2418
2 changed files with 16 additions and 1 deletions

View File

@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
return false;
}
} else {
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
if (src1_ncols > 16) {
return false;
}
}

View File

@ -8,6 +8,19 @@ using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
// TODO: submit a bug to rocm compiler, remove this when the bug is fixed.
// force rocm compiler to use more register and unroll code for mul_mat_f.
#if defined(RDNA4)
#define MMF_REGISTER_UNROLL_FOR_RDNA \
do { \
if (blockIdx.z == -1) { \
NO_DEVICE_CODE; \
} \
} while(0)
#else
#define MMF_REGISTER_UNROLL_FOR_RDNA
#endif // defined(RDNA4)
struct mmf_ids_data {
const int32_t * ids_src_compact = nullptr;
const int32_t * ids_dst_compact = nullptr;
@ -153,6 +166,7 @@ static __global__ void mul_mat_f(
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
MMF_REGISTER_UNROLL_FOR_RDNA;
}
}
@ -191,6 +205,7 @@ static __global__ void mul_mat_f(
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
MMF_REGISTER_UNROLL_FOR_RDNA;
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);