faster mmf
This commit is contained in:
parent
7444a7a18b
commit
250ae9aee8
|
|
@ -993,17 +993,10 @@ namespace ggml_cuda_mma {
|
|||
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
||||
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
|
||||
#if defined(CDNA3)
|
||||
#if 0
|
||||
using floatx2_t = __attribute__((ext_vector_type(2))) float;
|
||||
const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
|
||||
const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
|
||||
}
|
||||
#endif
|
||||
#elif defined(CDNA2) || defined(CDNA1)
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue