Added a std::is_same_v<dst_t, float> check to the ue4m3 branch of the cuda_cast template
This commit is contained in:
parent
7fd898beac
commit
caa8fba0cc
|
|
@ -62,8 +62,8 @@ template<typename dst_t, typename src_t>
|
||||||
#else
|
#else
|
||||||
return {x.x, x.y};
|
return {x.x, x.y};
|
||||||
#endif // GGML_USE_HIP
|
#endif // GGML_USE_HIP
|
||||||
} else if constexpr (std::is_same_v<src_t, ue4m3>) {
|
} else if constexpr (std::is_same_v<src_t, ue4m3> && std::is_same_v<dst_t, float>) {
|
||||||
#if defined(__CUDA_ARCH__)
|
#if defined(__CUDA_ARCH__)
|
||||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && defined(CUDART_VERSION) && CUDART_VERSION >= 12050 // This matches cuda_fp8.h's version gate.
|
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && defined(CUDART_VERSION) && CUDART_VERSION >= 12050 // This matches cuda_fp8.h's version gate.
|
||||||
// This uses the same fp8 conversion that __nv_fp8_e4m3 uses internally.
|
// This uses the same fp8 conversion that __nv_fp8_e4m3 uses internally.
|
||||||
__half h = __half(__nv_cvt_fp8_to_halfraw((__nv_fp8_storage_t) x.x, __NV_E4M3));
|
__half h = __half(__nv_cvt_fp8_to_halfraw((__nv_fp8_storage_t) x.x, __NV_E4M3));
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue