CUDA: fix BF16 FA compilation (#20865)

This commit is contained in:
Johannes Gäßler 2026-03-22 17:53:33 +01:00 committed by GitHub
parent 23c9182ce8
commit bd3f1d9d65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 7 additions and 3 deletions

View File

@ -42,11 +42,15 @@ template<typename dst_t, typename src_t>
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
return __float22half2_rn(x);
} else if constexpr(std::is_same_v<src_t, nv_bfloat162> && std::is_same_v<dst_t, float2>) {
#if !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#ifdef GGML_USE_HIP
return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x)));
#else
#if __CUDA_ARCH__ >= 800
return __bfloat1622float2(x);
#else
return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x)));
#endif // !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return make_float2(__bfloat162float(x.x), __bfloat162float(x.y));
#endif // __CUDA_ARCH__ >= 800
#endif // GGML_USE_HIP
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
// bypass compile error on cuda 12.0.1
#ifdef GGML_USE_HIP