From 2829ef17ad84e19c797786247f28a32c0863c1c2 Mon Sep 17 00:00:00 2001 From: Sam Kaufman Date: Tue, 30 Apr 2024 15:19:28 -0700 Subject: [PATCH] Check for HWY_NATIVE_DOT_BF16. --- gemma/ops.h | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/gemma/ops.h b/gemma/ops.h index 2b5dc39..1988bd8 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -149,6 +149,8 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs, } } + +#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16 template HWY_INLINE void MatVecAddLoop( @@ -174,6 +176,7 @@ HWY_INLINE void MatVecAddLoop( } } } +#endif template HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs, @@ -344,6 +347,7 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, constexpr size_t kRowsPerStrip = RowsPerStrip(); constexpr size_t kNumStrips = kOuter / kRowsPerStrip; + #if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16 if constexpr ( CompressTraits::kSupportsEvenOdd && hwy::IsSameEither() @@ -352,10 +356,12 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, ToEvenOddF32(vec_aligned, kInner, vec_dequant.get()); detail::MatVecAddInner( mat, mat_ofs, vec_dequant.get(), add, out, pool); - } else { - detail::MatVecAddInner( - mat, mat_ofs, vec_aligned, add, out, pool); + return; } + #endif + + detail::MatVecAddInner( + mat, mat_ofs, vec_aligned, add, out, pool); } template