diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 35735d48b2..f19defbff9 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -63,7 +63,7 @@ static __global__ void flash_attn_ext_f16( constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); -#if defined(GGML_USE_HIP) +#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000 typedef wmma::fragment frag_a_K; typedef wmma::fragment frag_a_V; typedef wmma::fragment frag_b; @@ -135,7 +135,7 @@ static __global__ void flash_attn_ext_f16( __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; -#if defined(GGML_USE_HIP) +#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000 const _Float16 * K_h_f16 = reinterpret_cast(K_h); const _Float16 * V_h_f16 = reinterpret_cast(V_h); _Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ);