enable mmf for rdna4
This commit is contained in:
parent
028f93ef98
commit
698c9f2418
|
|
@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
|||
return false;
|
||||
}
|
||||
} else {
|
||||
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||
if (src1_ncols > 16) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,19 @@ using namespace ggml_cuda_mma;
|
|||
|
||||
#define MMF_ROWS_PER_BLOCK 32
|
||||
|
||||
// TODO: submit a bug to rocm compiler, remove this when the bug is fixed.
|
||||
// force rocm compiler to use more register and unroll code for mul_mat_f.
|
||||
#if defined(RDNA4)
|
||||
#define MMF_REGISTER_UNROLL_FOR_RDNA \
|
||||
do { \
|
||||
if (blockIdx.z == -1) { \
|
||||
NO_DEVICE_CODE; \
|
||||
} \
|
||||
} while(0)
|
||||
#else
|
||||
#define MMF_REGISTER_UNROLL_FOR_RDNA
|
||||
#endif // defined(RDNA4)
|
||||
|
||||
struct mmf_ids_data {
|
||||
const int32_t * ids_src_compact = nullptr;
|
||||
const int32_t * ids_dst_compact = nullptr;
|
||||
|
|
@ -153,6 +166,7 @@ static __global__ void mul_mat_f(
|
|||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
|
||||
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
|
||||
MMF_REGISTER_UNROLL_FOR_RDNA;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -191,6 +205,7 @@ static __global__ void mul_mat_f(
|
|||
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
|
||||
tile_B B;
|
||||
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
|
||||
MMF_REGISTER_UNROLL_FOR_RDNA;
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
|
||||
|
|
|
|||
Loading…
Reference in New Issue