diff --git a/gemma/ops.h b/gemma/ops.h index 2b5dc39..1988bd8 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -149,6 +149,8 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs, } } + +#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16 template HWY_INLINE void MatVecAddLoop( @@ -174,6 +176,7 @@ HWY_INLINE void MatVecAddLoop( } } } +#endif template HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs, @@ -344,6 +347,7 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, constexpr size_t kRowsPerStrip = RowsPerStrip(); constexpr size_t kNumStrips = kOuter / kRowsPerStrip; + #if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16 if constexpr ( CompressTraits::kSupportsEvenOdd && hwy::IsSameEither() @@ -352,10 +356,12 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, ToEvenOddF32(vec_aligned, kInner, vec_dequant.get()); detail::MatVecAddInner( mat, mat_ofs, vec_dequant.get(), add, out, pool); - } else { - detail::MatVecAddInner( - mat, mat_ofs, vec_aligned, add, out, pool); + return; } + #endif + + detail::MatVecAddInner( + mat, mat_ofs, vec_aligned, add, out, pool); } template