diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index fc94873896..560d8e2742 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -772,6 +772,36 @@ namespace ggml_cuda_mma {
                                                        acc[0],
                                                        0, 0, 0);
 #endif // defined(CDNA3)
+
+#elif defined(AMD_WMMA_AVAILABLE)
+        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+        int32x2_t * a_vec = (int32x2_t *) A.x;
+        int32x2_t * b_vec = (int32x2_t *) B.x;
+
+        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+        int32x8_t * acc = (int32x8_t *) D.x;
+
+#if defined(RDNA4)
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            false,
+            a_vec[0],
+            false,
+            b_vec[0],
+            acc[0],
+            false
+        );
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            false,
+            a_vec[1],
+            false,
+            b_vec[1],
+            acc[0],
+            false
+        );
+#endif // defined(RDNA4)
+
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -799,21 +829,6 @@ namespace ggml_cuda_mma {
                                                        0, 0, 0);
 #endif // defined(CDNA3)
-#elif defined(AMD_WMMA_AVAILABLE)
-        using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
-        int32x16_t * acc = (int32x16_t *) D.x;
-
-#if defined(RDNA4)
-        acc[0] = __builtin_amdgcn_wmma_i32_32x32x16_i8(A.x[0],
-                                                       B.x[0],
-                                                       acc[0],
-                                                       0, 0, 0);
-        acc[0] = __builtin_amdgcn_wmma_i32_32x32x16_i8(A.x[1],
-                                                       B.x[1],
-                                                       acc[0],
-                                                       0, 0, 0);
-#endif // defined(RDNA4)
-
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
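
Note on the new RDNA4 path (a minimal sketch, not part of the patch): the builtin's name encodes a 16x16 result tile with K = 16 per issue, so the 16-byte-per-lane A/B fragments are split into two int32x2_t halves and the instruction is issued twice into the same 8-register accumulator to cover the fragment's K = 32. The helper name and standalone form below are mine; the bool immediates are passed exactly as in the patch.

    // Sketch of the accumulation pattern used above, assuming HIP on gfx12 (RDNA4).
    using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
    using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;

    // Hypothetical helper illustrating the two-step K = 32 accumulation.
    static __device__ void wmma_iu8_k32(int32x8_t & acc, const int * A_x, const int * B_x) {
        const int32x2_t * a_vec = (const int32x2_t *) A_x; // two K = 16 halves of the A fragment
        const int32x2_t * b_vec = (const int32x2_t *) B_x; // two K = 16 halves of the B fragment

        // First K = 16 slice of the dot product.
        acc = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(false, a_vec[0], false, b_vec[0], acc, false);
        // Second K = 16 slice accumulates into the same registers, completing K = 32.
        acc = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(false, a_vec[1], false, b_vec[1], acc, false);
    }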