mirror of https://github.com/google/gemma.cpp.git
Check for HWY_NATIVE_DOT_BF16.
This commit is contained in:
parent
59ebecce22
commit
2829ef17ad
12
gemma/ops.h
12
gemma/ops.h
|
|
@ -149,6 +149,8 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
|
||||
template <bool kAdd, size_t kOuter, size_t kInner, typename VecT, typename AddT,
|
||||
size_t kCapacity>
|
||||
HWY_INLINE void MatVecAddLoop(
|
||||
|
|
@ -174,6 +176,7 @@ HWY_INLINE void MatVecAddLoop(
|
|||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
||||
HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
|
||||
|
|
@ -344,6 +347,7 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
|||
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
||||
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
||||
|
||||
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
|
||||
if constexpr (
|
||||
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd
|
||||
&& hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()
|
||||
|
|
@ -352,10 +356,12 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
|||
ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
|
||||
detail::MatVecAddInner<true, kAdd, kOuter, kInner>(
|
||||
mat, mat_ofs, vec_dequant.get(), add, out, pool);
|
||||
} else {
|
||||
detail::MatVecAddInner<false, kAdd, kOuter, kInner>(
|
||||
mat, mat_ofs, vec_aligned, add, out, pool);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
detail::MatVecAddInner<false, kAdd, kOuter, kInner>(
|
||||
mat, mat_ofs, vec_aligned, add, out, pool);
|
||||
}
|
||||
|
||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
||||
|
|
|
|||
Loading…
Reference in New Issue