Added a dst_t check to the cuda_cast template so the ue4m3 branch is taken only when converting to float

This commit is contained in:
Michael Wand 2026-03-24 02:04:55 -07:00
parent 7fd898beac
commit caa8fba0cc
1 changed file with 2 additions and 2 deletions

View File

@@ -62,8 +62,8 @@ template<typename dst_t, typename src_t>
#else
return {x.x, x.y};
#endif // GGML_USE_HIP
} else if constexpr (std::is_same_v<src_t, ue4m3>) {
#if defined(__CUDA_ARCH__)
} else if constexpr (std::is_same_v<src_t, ue4m3> && std::is_same_v<dst_t, float>) {
#if defined(__CUDA_ARCH__)
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && defined(CUDART_VERSION) && CUDART_VERSION >= 12050 // This matches cuda_fp8.h's version gate.
// This uses the same fp8 conversion that __nv_fp8_e4m3 uses internally.
__half h = __half(__nv_cvt_fp8_to_halfraw((__nv_fp8_storage_t) x.x, __NV_E4M3));