Check for HWY_NATIVE_DOT_BF16.

This commit is contained in:
Sam Kaufman 2024-04-30 15:19:28 -07:00
parent 59ebecce22
commit 2829ef17ad
1 changed file with 9 additions and 3 deletions

View File

@@ -149,6 +149,8 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
} }
} }
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
template <bool kAdd, size_t kOuter, size_t kInner, typename VecT, typename AddT, template <bool kAdd, size_t kOuter, size_t kInner, typename VecT, typename AddT,
size_t kCapacity> size_t kCapacity>
HWY_INLINE void MatVecAddLoop( HWY_INLINE void MatVecAddLoop(
@@ -174,6 +176,7 @@ HWY_INLINE void MatVecAddLoop(
} }
} }
} }
#endif
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT> template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs, HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
@@ -344,6 +347,7 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>(); constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip; constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
if constexpr ( if constexpr (
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd
&& hwy::IsSameEither<VecT, float, hwy::bfloat16_t>() && hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()
@@ -352,10 +356,12 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
ToEvenOddF32(vec_aligned, kInner, vec_dequant.get()); ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
detail::MatVecAddInner<true, kAdd, kOuter, kInner>( detail::MatVecAddInner<true, kAdd, kOuter, kInner>(
mat, mat_ofs, vec_dequant.get(), add, out, pool); mat, mat_ofs, vec_dequant.get(), add, out, pool);
} else { return;
}
#endif
detail::MatVecAddInner<false, kAdd, kOuter, kInner>( detail::MatVecAddInner<false, kAdd, kOuter, kInner>(
mat, mat_ofs, vec_aligned, add, out, pool); mat, mat_ofs, vec_aligned, add, out, pool);
}
} }
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT> template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>