Added check for dst_t to cuda_cast template for float
This commit is contained in:
parent
7fd898beac
commit
caa8fba0cc
|
|
@ -62,8 +62,8 @@ template<typename dst_t, typename src_t>
|
|||
#else
|
||||
return {x.x, x.y};
|
||||
#endif // GGML_USE_HIP
|
||||
} else if constexpr (std::is_same_v<src_t, ue4m3>) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
} else if constexpr (std::is_same_v<src_t, ue4m3> && std::is_same_v<dst_t, float>) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && defined(CUDART_VERSION) && CUDART_VERSION >= 12050 // This matches cuda_fp8.h's version gate.
|
||||
// This uses the same fp8 conversion that __nv_fp8_e4m3 uses internally.
|
||||
__half h = __half(__nv_cvt_fp8_to_halfraw((__nv_fp8_storage_t) x.x, __NV_E4M3));
|
||||
|
|
|
|||
Loading…
Reference in New Issue