mirror of https://github.com/google/gemma.cpp.git
Check for HWY_NATIVE_DOT_BF16.
This commit is contained in:
parent
59ebecce22
commit
2829ef17ad
10
gemma/ops.h
10
gemma/ops.h
|
|
@ -149,6 +149,8 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
|
||||||
template <bool kAdd, size_t kOuter, size_t kInner, typename VecT, typename AddT,
|
template <bool kAdd, size_t kOuter, size_t kInner, typename VecT, typename AddT,
|
||||||
size_t kCapacity>
|
size_t kCapacity>
|
||||||
HWY_INLINE void MatVecAddLoop(
|
HWY_INLINE void MatVecAddLoop(
|
||||||
|
|
@ -174,6 +176,7 @@ HWY_INLINE void MatVecAddLoop(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
||||||
HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
|
HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
|
||||||
|
|
@ -344,6 +347,7 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
||||||
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
||||||
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
||||||
|
|
||||||
|
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
|
||||||
if constexpr (
|
if constexpr (
|
||||||
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd
|
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd
|
||||||
&& hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()
|
&& hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()
|
||||||
|
|
@ -352,10 +356,12 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
||||||
ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
|
ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
|
||||||
detail::MatVecAddInner<true, kAdd, kOuter, kInner>(
|
detail::MatVecAddInner<true, kAdd, kOuter, kInner>(
|
||||||
mat, mat_ofs, vec_dequant.get(), add, out, pool);
|
mat, mat_ofs, vec_dequant.get(), add, out, pool);
|
||||||
} else {
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
detail::MatVecAddInner<false, kAdd, kOuter, kInner>(
|
detail::MatVecAddInner<false, kAdd, kOuter, kInner>(
|
||||||
mat, mat_ofs, vec_aligned, add, out, pool);
|
mat, mat_ofs, vec_aligned, add, out, pool);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue